# ------------------------------------------------------------------------------
# Preps cohort for analysis
# From: stress_test_medicare repo <https://gitlab.com/labsysmed/zolab-projects/stress_test_medicare/-/blob/cabg_outcomes/code/03_analysis/02_prep-holdout/01_prep-holdout.R>
# Updates author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 05_build_cohort.sh {overnight}
# ------------------------------------------------------------------------------

# Seed ----------------------------------------------------------
# Fix the RNG state up front so that any random draws made later in the run
# are reproducible.
# NOTE(review): no statement visible in this script draws random numbers
# itself; the seed only matters if a sourced helper does — confirm.
set.seed(950169)

# Libraries ----------------------------------------------------------
library(here)
library(yaml)
library(Matrix)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(RSQLite) # get granular time data from master ed encounters
library(glue) # glue strings
library(reticulate) # source_python()
library(lubridate) # force_tz()
library(feather) # read master encounters
library(optparse) # bash options

u <- modules::use(here::here("lib", "util.R"))
source_python(here::here("lib", "python_util.py"))

# Command Line Args ------------------------------------------------------------
# `--split` names the split to process. It is used as a path component for the
# prediction directories, and the value "overnight" is special-cased when the
# configuration is built.
arg_config <- list(make_option("--split", type = "character"))
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# Fail fast with a clear message. Without this check a missing `--split`
# leaves `split` as NULL and the script only dies later, on a comparison
# against "overnight", with a much less informative error.
if (is.null(arg_list$split) || is.na(arg_list$split)) {
  stop("Missing required argument: --split", call. = FALSE)
}
split <- arg_list$split

# Config -----------------------------------------------------------------------
message("Configuring...")
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Prediction directories for this split. The `oos` root is where the training
# set scores are read from; the `dir` root holds the test set scores.
prediction_train <- file.path(paths$modeling$oos, "prediction", split, "all")
prediction_dir <- file.path(paths$modeling$dir, "prediction", split, "all")
cohorts <- list(train = "train", test = "test", val = "val")

# The overnight split gets its own label and its own id file key.
# NOTE(review): `overnight_lab` is not referenced by name in this script;
# presumably it is interpolated by the glue() path templates — confirm in
# lib/filepaths.yml.
if (arg_list$split == "overnight") {
  overnight_lab <- "_overnight"
  ids_split <- "overnight_split"
} else {
  overnight_lab <- ""
  ids_split <- "ids_flags"
}

# Load Data --------------------------------------------------------------------
message("Loading cohort...")
# Encounter id -> split assignment, read from the id file chosen in the config.
ids <- select(read_csv(paths$cohort[[ids_split]]), ed_enc_id, split)

# The modeling paths are glue templates, filled in from variables in scope.
cohort_train <- readRDS(glue(paths$modeling$train))
cohort_val <- readRDS(glue(paths$modeling$val))
cohort_test <- readRDS(glue(paths$modeling$test))

# Keep only the columns shared by all three cohorts so they can be stacked.
common_cols <- Reduce(
  intersect,
  list(names(cohort_train), names(cohort_val), names(cohort_test))
)
cohort_train <- cohort_train %>% select(all_of(common_cols))
cohort_val <- cohort_val %>% select(all_of(common_cols))
cohort_test <- cohort_test %>% select(all_of(common_cols))

# Stack in train, val, test order.
cohort <- rbind(rbind(cohort_train, cohort_val), cohort_test)

# Load y-hats ------------------------------------------------------------------
message("Loading scores...")
# Predictions for the test set.
scores <- prediction_dir %>%
  file.path("scores_test_set.rds") %>%
  readRDS()

# Out-of-sample predictions for the training set.
train_scores <- prediction_train %>%
  file.path("scores_train_set.rds") %>%
  readRDS()

# Demographics -----------------------------------------------------------------
message("Adding demographic data to cohort...")
# Only the encounter id and three demographic fields are carried over.
dems <- select(
  readRDS(paths$analysis$demographics),
  ed_enc_id, age_at_admit, dist_BWH_miles, death_030_day
)
# NOTE(review): join keys are chosen inside u$safe_left_join — presumably the
# shared column names (ed_enc_id here); confirm in lib/util.R.
visits_all <- u$safe_left_join(cohort, dems)

# ECGS -------------------------------------------------------------------------
message("Adding ecg data to cohort...")
ecgs <- readRDS(paths$analysis$ecgs)
visits_all <- u$safe_left_join(visits_all, ecgs)

# Every encounter flagged `test_010_day` is marked as having an ECG; all other
# rows keep the joined `has_ecg` value (a missing `test_010_day` yields NA).
# NOTE(review): `test_010_day` here is the cohort's own column, which is
# dropped and replaced from the ed_tests file further down — confirm that the
# earlier version is the intended one for this flag.
visits_all <- mutate(
  visits_all,
  has_ecg = ifelse(test_010_day, TRUE, has_ecg)
)

# Exclusions -------------------------------------------------------------------
message("Building universal exclusion variable...")
ami <- readRDS(paths$cohort$ami_encounters)
sameday_tn <- readRDS(paths$analysis$sameday_tn)

visits_all <- u$safe_left_join(visits_all, ami)
visits_all <- u$safe_left_join(visits_all, sameday_tn)

# An encounter is excluded when any of the following holds:
#   * one of the three `excl_flag_*` columns is set,
#   * age at admission is 80 or above,
#   * it is untested and `ami_day_of` is set,
#   * it is untested and its same-day troponin group is neither "None"
#     nor "0".
# Missing inputs propagate, so `exclude` can be NA.
visits_all <- mutate(
  visits_all,
  exclude = excl_flag_c_int | excl_flag_chronic | excl_flag_death |
    age_at_admit >= 80 |
    (!test_010_day & ami_day_of) |
    (!test_010_day & tn_group_sameday != "None" & tn_group_sameday != "0")
)

# Time variables  --------------------------------------------------------------
message("Adding time variables...")
# Calendar features derived from the encounter start time. `hour` stays
# numeric; year, weekday and week are stored as factors.
visits_all <- mutate(
  visits_all,
  hour = lubridate::hour(start_datetime),
  year = as.factor(lubridate::year(start_datetime)),
  wday = as.factor(lubridate::wday(start_datetime)),
  week = as.factor(lubridate::week(start_datetime))
)

# Disch Codes  -----------------------------------------------------------------
message("Creating discharge variables...")
# Encounter crosswalk: rename the row id to the key used by the cohort tables.
ed_enc <- read_feather(paths$cohort$ed_enc_xwalk)
setnames(ed_enc, "enc_row_id", "ed_enc_id")
ed_enc <- mutate(ed_enc, ptid = as.numeric(ptid))

# Keep the physician fields and the discharge disposition (renamed to
# `disch_disp`); `disch_obs` flags the disposition codes "e" and "edob".
disch_df <- ed_enc %>%
  select(
    ed_enc_id, enc_md_id, enc_md_name,
    disch_disp = discharge_disposition_sup
  ) %>%
  mutate(disch_obs = disch_disp == "e" | disch_disp == "edob")

visits_all <- u$safe_left_join(visits_all, disch_df)

# Use fixed stress/cath codes --------------------------------------------------
message("Adding test outcomes...")
# Days from encounter start to each test, plus the 10-day testing flags.
tests <- readRDS(paths$cohort$ed_tests) %>%
  mutate(
    days_to_cath = as.numeric(cath_date - start_date),
    days_to_stress = as.numeric(ett_date - start_date)
  ) %>%
  select(
    ed_enc_id, test_010_day, stress_010_day, cath_010_day,
    days_to_stress, days_to_cath
  )

# Replace the cohort's own `test_010_day` with the version from `tests`.
drop_t_vars <- c("test_010_day")
visits_all <- select(visits_all, -all_of(drop_t_vars))
visits_all <- u$safe_left_join(visits_all, tests)

# Which test came first: the only test within 10 days if there was just one,
# otherwise whichever has the smaller day count (ties go to "stress").
# NOTE(review): rows where the day comparison is NA or FALSE — including rows
# with neither flag set — fall through to "cath"; confirm that downstream code
# only reads `first_test` for tested encounters.
visits_all <- mutate(
  visits_all,
  first_test = case_when(
    stress_010_day & !cath_010_day ~ "stress",
    !stress_010_day & cath_010_day ~ "cath",
    days_to_stress <= days_to_cath ~ "stress",
    TRUE ~ "cath"
  )
)

# Add scores to test cohort ----------------------------------------------------
message("Adding scores to test cohort...")
# Test cohort: drop the stale test flag, attach the enriched covariates, the
# test set scores and the split ids, then convert to a data.table in place.
visits <- select(cohort_test, -all_of(drop_t_vars))
visits <- u$safe_left_join(visits, visits_all)
visits <- u$safe_left_join(visits, scores)
visits <- u$safe_left_join(visits, ids)
setDT(visits)

# Training cohort: same recipe, using the training set scores.
train_visits <- select(cohort_train, -all_of(drop_t_vars))
train_visits <- u$safe_left_join(train_visits, visits_all)
train_visits <- u$safe_left_join(train_visits, train_scores)
train_visits <- u$safe_left_join(train_visits, ids)
setDT(train_visits)

# Score Tiles ------------------------------------------------------------------
# For every risk score, tile count and population, add a column named
# `tile_{outcome}_{tile count, zero-padded to 3}{population}{dropcc suffix}`
# (e.g. `tile_stent_or_cabg_004_tested`) holding the result of
# u$ntile_within() restricted to the eligible rows.
# NOTE(review): u$ntile_within is defined in lib/util.R; presumably it bins
# the score into `tile` quantile groups using only the rows in `which_x` —
# confirm there.
message("Breaking scores into tiles...")
risk_vars <- c("p__ensemble__stent_or_cabg_010_day__tested")
# Tiles are also added to the training cohort, but only for this score.
train_risk_var <- "p__ensemble__stent_or_cabg_010_day__tested"
tile_counts <- c(4, 5, 10, 20, 100)
populations <- c("_tested", "_untested", "_all")

# Rows eligible for a population: excluded rows are always dropped, and
# "_tested" / "_untested" additionally filter on the 10-day test flag.
population_rows <- function(df, population) {
  keep <- !df$exclude
  if (population == "_tested") {
    keep <- keep & df$test_010_day
  } else if (population == "_untested") {
    keep <- keep & !df$test_010_day
  }
  keep
}

# The masks depend only on the population, so build them once up front rather
# than once per (risk, tile, population) combination.
keep_masks <- lapply(populations, function(p) population_rows(visits, p))
names(keep_masks) <- populations
keep_masks_train <- lapply(
  populations,
  function(p) population_rows(train_visits, p)
)
names(keep_masks_train) <- populations

for (risk in risk_vars) {
  # Outcome label and suffix depend only on the risk variable name. An
  # unrecognised name is an error rather than a silently mislabelled column.
  if (grepl("stent", risk)) {
    outcome <- "stent_or_cabg"
  } else if (grepl("mace", risk)) {
    outcome <- "mace"
  } else if (grepl("test", risk)) {
    outcome <- "test"
  } else {
    stop("Cannot infer outcome from risk variable: ", risk, call. = FALSE)
  }
  dropcc_lab <- if (grepl("dropcc", risk)) "_dropcc" else ""

  for (tile in tile_counts) {
    tile_lab <- str_pad(tile, 3, pad = "0")

    for (population in populations) {
      message("Risk: ", risk)
      message("Tiles: ", tile)
      message("Population: ", str_replace(population, "_", ""))

      keep_x <- keep_masks[[population]]
      keep_x_train <- keep_masks_train[[population]]
      tile_var <- glue("tile_{outcome}_{tile_lab}{population}{dropcc_lab}")

      # which() drops NA entries, so rows with an unknown exclusion or test
      # status are not part of the eligible set.
      visits[[tile_var]] <- u$ntile_within(
        visits[[risk]], tile,
        which_x = which(keep_x)
      )
      if (risk == train_risk_var) {
        train_visits[[tile_var]] <- u$ntile_within(
          train_visits[[risk]], tile,
          which_x = which(keep_x_train)
        )
      }
    }
  }
}

# Columns present in both the scored test cohort and the scored train cohort.
common_cols <- intersect(names(visits), names(train_visits))

# The validation cohort is not joined to any scores and gets no tiles: attach
# the enriched covariates, then add every missing shared column as NA.
val_visits <- u$safe_left_join(
  select(cohort_val, -all_of(drop_t_vars)),
  visits_all
)
missing_in_val <- setdiff(common_cols, names(val_visits))
for (colname in missing_in_val) {
  val_visits[colname] <- NA
}
assert(
  "Names of val_visits df match common_cols",
  setequal(names(val_visits), common_cols)
)

# Stack train, test and validation rows, in that order, on the shared columns.
train_part <- select(train_visits, all_of(common_cols))
test_part <- select(visits, all_of(common_cols))
visits_all <- rbind(rbind(train_part, test_part), val_visits)

# Save -------------------------------------------------------------------------
message("Saving analysis cohorts...")
# Output paths are glue templates, filled in from variables in scope.
test_cohort_path <- glue(paths$analysis$test_cohort)
full_cohort_path <- glue(paths$analysis$full_cohort)

# `visits` is the scored test cohort; `visits_all` is the stacked
# train/test/validation cohort.
write_rds(visits, test_cohort_path)
write_rds(visits_all, full_cohort_path)

# Done -------------------------------------------------------------------------
message("Done.")
